4.1 병렬 스캔(Parallel Scan)을 통한 학습 효율성 확보
1. 서론: 시퀀스 모델링의 근본적 병목과 패러다임의 전환
인공지능, 특히 시퀀스 모델링(Sequence Modeling) 분야의 역사는 ’문맥(Context)’을 어떻게 효율적으로 압축하고 복원할 것인가에 대한 투쟁의 기록이다. 순환 신경망(RNN)은 메모리를 적게 사용하고 추론 속도가 빠르다는 장점에도 불구하고, 학습 시 시간 차원(Time Dimension)을 따라 순차적으로 계산해야 한다는 구조적 제약으로 인해 현대 하드웨어의 병렬 처리 능력을 온전히 활용하지 못했다. 반면, 트랜스포머(Transformer)는 어텐션(Attention) 메커니즘을 통해 모든 토큰 간의 상호작용을 병렬로 처리함으로써 학습 속도를 비약적으로 향상시켰으나, 시퀀스 길이 L에 대해 O(L^2)이라는 2차(Quadratic) 복잡도를 가짐으로써 긴 문맥 처리에 치명적인 한계를 드러냈다.1
Mamba 아키텍처는 이러한 이분법적 딜레마를 해결하기 위해 등장했다. Mamba의 핵심인 **선택적 상태 공간 모델(Selective State Space Model, Selective SSM)**은 입력에 따라 시스템의 파라미터가 동적으로 변화하는 시변(Time-Variant) 특성을 가진다. 이는 모델이 불필요한 정보를 망각하고 중요한 정보를 선별적으로 기억할 수 있게 하는 강력한 기능이지만, 동시에 기존 S4(Structured State Spaces) 모델들이 누렸던 ’컨볼루션(Convolution)을 통한 병렬 학습’의 이점을 포기하게 만든다.3 시변 시스템에서는 커널이 고정되지 않으므로 고속 푸리에 변환(FFT)을 적용할 수 없기 때문이다.
따라서 Mamba가 트랜스포머를 대체할 차세대 아키텍처로 자리 잡기 위해서는, 순차적 재귀(Recurrence) 구조를 유지하면서도 트랜스포머 수준의 병렬 학습 속도를 달성해야 하는 난제를 해결해야 한다. 본 장에서는 Mamba가 이 불가능해 보이는 과제를 어떻게 병렬 스캔(Parallel Scan) 알고리즘과 하드웨어 인식(Hardware-Aware) 설계를 통해 해결했는지 심층적으로 분석한다. 이는 단순한 알고리즘의 적용을 넘어, 수학적 연산자의 재정의와 GPU 메모리 계층 구조에 대한 극한의 최적화가 결합된 시스템 엔지니어링의 정수이다.
2. 순차적 재귀(Sequential Recurrence)의 수학적 구조와 한계
2.1 선형 재귀(Linear Recurrence)의 정의
Mamba의 기본 골격이 되는 이산 시간(Discrete-Time) SSM은 다음과 같은 재귀 식(Recurrence Formulation)으로 표현된다.3
h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t
y_t = C_t h_t
여기서:
- t는 시간 단계(Time step)를 의미하며 t \in {1, \dots, L}이다.
- h_t \in \mathbb{R}^N은 잠재 상태(Latent State) 벡터이다.
- x_t \in \mathbb{R}^D는 입력 벡터, y_t \in \mathbb{R}^D는 출력 벡터이다.
- \bar{A}_t \in \mathbb{R}^{N \times N}와 \bar{B}_t \in \mathbb{R}^{N \times D}는 이산화(Discretization)된 시스템 행렬이다.
일반적인 RNN(예: LSTM, GRU)은 상태 업데이트 식에 \tanh나 \sigma와 같은 비선형 활성화 함수가 포함되어 있어, h_{t-1}이 완전히 계산되어야만 h_t를 계산할 수 있는 강력한 직렬 의존성(Serial Dependency)을 가진다. 그러나 Mamba의 SSM은 상태 전이 과정이 **선형(Linear)**이라는 결정적인 차이점을 가진다.1 비선형성은 상태 전이 자체가 아닌, 구조적인 블록(Block) 단위나 출력 단계에서 적용된다. 이 ’선형성’은 병렬화가 불가능해 보이는 순차적 과정을 병렬 연산으로 변환할 수 있는 수학적 단초를 제공한다.
2.2 선택적 메커니즘과 컨볼루션의 불가능성
기존의 S4와 같은 선형 시불변(LTI) SSM은 \bar{A}_t = \bar{A}, \bar{B}_t = \bar{B}로 고정되어 있었다. 이 경우 전체 시퀀스의 출력 y는 입력 x와 글로벌 커널 \bar{K}의 컨볼루션으로 표현될 수 있다 (y = x * \bar{K}). 이는 FFT를 통해 O(L \log L) 복잡도로 매우 효율적으로 계산된다.6
하지만 Mamba는 모델의 표현력을 극대화하기 위해 ‘선택(Selection)’ 메커니즘을 도입했다. 즉, \bar{A}_t와 \bar{B}_t가 현재 입력 x_t의 함수가 된다 (\bar{A}_t(x_t), \bar{B}_t(x_t)). 이는 매 시점마다 시스템의 역학(Dynamics)이 달라짐을 의미하므로, 단일 커널을 사용하는 컨볼루션 연산이 불가능하다.7 표면적으로 이는 O(L)의 순차적 계산, 즉 느린 학습 속도로의 회귀를 의미하는 것처럼 보인다.
2.3 GPU 하드웨어와 순차적 연산의 부조화
현대 딥러닝 학습의 주력 하드웨어인 GPU는 수천 개의 코어를 사용하여 대량의 병렬 연산을 수행하는 데 최적화되어 있다. 순차적 연산은 다음과 같은 이유로 GPU 효율성을 심각하게 저하시킨다.8
- 가동률(Utilization) 저하: 한 번에 하나의 시간 단계만 계산하므로, GPU의 수많은 코어 중 극히 일부만 사용된다.
- 메모리 지연(Latency): 매 단계마다 결과를 메모리에 쓰고, 다음 단계에서 다시 읽어와야 하므로 메모리 접근 지연이 누적된다.
- 커널 실행 오버헤드: 짧은 연산을 반복적으로 수행하기 위해 수많은 커널을 실행(Launch)해야 하며, 이에 따른 오버헤드가 실제 연산 시간보다 커질 수 있다.
따라서 Mamba가 실용적인 학습 속도를 확보하기 위해서는 순차적 의존성을 유지하면서도 연산을 병렬화할 수 있는 알고리즘적 돌파구가 필수적이었다.
3. 병렬 결합 스캔(Parallel Associative Scan) 알고리즘
Mamba는 재귀적 연산을 병렬 결합 스캔(Parallel Associative Scan), 혹은 접두어 합(Prefix Sum) 문제로 재해석하여 해결했다. 이는 1990년 Blelloch가 제안한 알고리즘을 딥러닝 맥락에 맞게 재구성한 것이다.1
3.1 결합 법칙(Associativity)의 중요성
병렬 스캔이 가능하기 위한 유일하고도 가장 중요한 조건은 연산자 \bullet가 **결합 법칙(Associativity)**을 만족해야 한다는 것이다.1
(a \bullet b) \bullet c = a \bullet (b \bullet c)
결합 법칙이 성립하면 연산의 순서를 임의로 재그룹화(Regrouping)할 수 있다. 예를 들어, (x_1 \bullet x_2) \bullet (x_3 \bullet x_4)와 같이 인접한 요소들을 병렬로 묶어 계산한 뒤, 그 결과들을 다시 결합하는 계층적(Hierarchical) 접근이 가능해진다.
3.2 SSM을 위한 이항 연산자(Binary Operator) 정의
Mamba의 상태 업데이트 식 h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t를 병렬 스캔 형태로 변환하기 위해, 우리는 각 시간 단계의 파라미터를 하나의 요소로 보고, 두 개의 연속된 시간 구간을 결합하는 연산자를 정의해야 한다.
각 시간 t에서의 요소를 u_t = (\bar{A}_t, \bar{B}_t x_t)라고 정의하자. 편의상 v_t = \bar{B}_t x_t라 하면, u_t = (\bar{A}_t, v_t)가 된다. 두 상태 u_i = (A_i, v_i)와 u_j = (A_j, v_j) (단, i < j)를 결합하는 연산자 \bullet는 다음과 같이 정의된다.1
(A_j, v_j) \bullet (A_i, v_i) = (A_j A_i, A_j v_i + v_j)
이 연산의 물리적 의미는 다음과 같다:
- A_j A_i (상태 전이의 합성): 시간 i에서 j까지 시스템이 전이될 때의 누적된 변환 행렬이다. i에서 j-1을 거쳐 j로 가는 경로의 효과를 곱셈으로 연결한다.
- A_j v_i + v_j (입력의 누적): 시간 i의 입력(v_i)이 시간 j로 넘어오면서 A_j만큼 변환되고, 여기에 시간 j의 새로운 입력(v_j)이 더해진 값이다.
이 연산자 \bullet가 결합 법칙을 만족함은 행렬 곱셈과 덧셈의 결합 법칙에 의해 증명된다. 따라서 전체 시퀀스에 대한 상태 h_t는 초기 상태 h_0와 입력 시퀀스 u_{1 \dots t}에 대해 접두어 스캔(Prefix Scan)을 수행함으로써 병렬로 계산될 수 있다.
3.3 Blelloch 스캔 알고리즘의 동작 원리
Mamba는 Blelloch 알고리즘을 사용하여 O(L)의 작업량(Work)과 O(\log L)의 깊이(Depth)로 스캔을 수행한다. 이 알고리즘은 크게 두 단계(Pass)로 구성된다.1
- Up-Sweep (Reduce Phase):
- 인접한 두 요소를 짝지어 연산자 \bullet를 적용한다.
- L개의 요소는 L/2개로 줄어들고, 이를 재귀적으로 반복하여 트리의 루트(Root)까지 올라간다.
- 이 과정은 시퀀스의 전체 요약 정보를 계층적으로 생성한다. 총 \log_2 L 단계가 소요된다.
- Down-Sweep (Distribute Phase):
- 루트에서부터 다시 아래로 내려오면서, 각 노드에 저장된 요약 정보와 부모 노드로부터 전달받은 값을 결합한다.
- 최종적으로 리프(Leaf) 노드에 도달하면 모든 위치 t에 대한 누적 값(Prefix Sum), 즉 잠재 상태 h_t가 완성된다.
- 이 역시 \log_2 L 단계가 소요된다.
이러한 트리 기반 접근 방식 덕분에, Mamba는 시퀀스 길이가 100만(1M) 토큰에 달하더라도 학습 시간을 선형적으로 유지할 수 있다. 이는 트랜스포머의 어텐션 메커니즘이 O(1) 깊이(Depth)를 가지지만 O(L^2)의 연산량(Work)을 요구하는 것과 대비되어, O(L)의 연산량과 O(\log L)의 깊이를 통해 최적의 균형점을 달성한다.2
| 비교 항목 | 순차적 RNN (Sequential) | 트랜스포머 (Transformer) | Mamba (Parallel Scan) |
|---|---|---|---|
| 시간 복잡도 (Step/Depth) | O(L) (병렬화 불가) | O(1) (완전 병렬) | O(\log L) (로그 병렬) |
| 연산 복잡도 (Work) | O(L \cdot N^2) | O(L^2 \cdot D) | O(L \cdot N^2) |
| 메모리 복잡도 | O(L \cdot N) | O(L^2 + L \cdot D) | O(L \cdot N) |
| 장기 의존성 학습 | 어려움 (Vanishing Gradient) | 우수하나 비용 과다 | 우수하며 효율적 |
4. 하드웨어 인식(Hardware-Aware) 구현: 이론을 현실로
병렬 스캔 알고리즘 자체는 수학적으로 완벽하지만, 실제 GPU 하드웨어에서 이를 구현할 때는 **메모리 계층(Memory Hierarchy)**이라는 물리적 장벽에 부딪힌다. Mamba의 혁신성은 알고리즘뿐만 아니라, 이를 하드웨어 수준에서 최적화한 구현 전략에 있다.3
4.1 GPU 메모리 벽(Memory Wall)과 I/O 병목
최신 GPU(A100, H100 등)의 연산 속도(FLOPS)는 메모리 대역폭(Bandwidth)보다 훨씬 빠르게 발전해 왔다. 따라서 딥러닝 모델의 학습 속도는 실제 연산 시간보다는 데이터를 메모리에서 프로세서로 이동시키는 시간에 의해 결정되는 경우가 많다. 이를 ‘메모리 벽(Memory Wall)’ 문제라고 한다.9
- HBM (High Bandwidth Memory): GPU의 메인 메모리로, 용량은 크지만(40GB~80GB) 접근 속도가 상대적으로 느리다.
- SRAM (Static RAM / Shared Memory): GPU 코어에 인접한 고속 메모리로, 접근 속도가 매우 빠르지만 용량이 극히 제한적이다(SM당 100KB~200KB 수준).
Mamba의 SSM 상태 h는 크기가 (B, L, D, N)이다. 여기서 N(상태 차원)은 보통 16 이상이므로, 입력 x (B, L, D)보다 N배 더 많은 메모리를 차지한다. 만약 병렬 스캔의 중간 결과를 HBM에 모두 기록(Materialization)한다면, 폭증하는 I/O 트래픽으로 인해 병렬화의 이점이 모두 사라지게 된다.3
4.2 커널 퓨전(Kernel Fusion) 전략
Mamba는 **커널 퓨전(Kernel Fusion)**을 통해 이 문제를 해결한다. 핵심 원칙은 **“확장된 상태 h를 절대 HBM에 기록하지 않는다”**는 것이다.
Mamba의 selective_scan_cuda 커널은 다음과 같은 절차로 동작한다 14:
- 데이터 로드 (HBM \to SRAM):
- 입력 x, 매개변수 \Delta, A, B, C 등 필요한 데이터만 HBM에서 SRAM으로 읽어온다. 이는 O(B L D + D N)의 크기로, 전체 상태 크기에 비해 훨씬 작다.
- SRAM 내부 연산 (On-Chip Computing):
- 이산화(Discretization): SRAM에 로드된 파라미터를 사용하여 \bar{A}, \bar{B}를 계산한다.
- 병렬 스캔(Parallel Scan): 계산된 파라미터로 SRAM 내부에서 즉시 병렬 스캔을 수행한다. 이 과정에서 생성되는 중간 상태 h_t들은 레지스터나 SRAM 내에만 머무르며 외부로 나가지 않는다.
- 출력 투영(Output Projection): 스캔 결과인 h_t에 C를 곱하여 최종 출력 y_t를 계산한다.
- 결과 저장 (SRAM \to HBM):
- 차원이 축소된 출력 y_t (크기 B \times L \times D)만을 HBM에 기록한다.
이러한 퓨전 전략을 통해 Mamba는 메모리 I/O 요구량을 이론적 최소치로 낮추었으며, 이는 FlashAttention이 O(N^2) 어텐션 행렬을 HBM에 쓰지 않음으로써 속도를 높인 원리와 정확히 일치한다.6
4.3 재계산(Recomputation)을 통한 역전파 최적화
학습(Training) 단계에서는 역전파(Backpropagation)를 위해 순전파(Forward pass) 때 계산한 중간값들이 필요하다. 그러나 Mamba는 메모리 절약을 위해 중간 상태 h를 HBM에 저장하지 않았다.
이를 해결하기 위해 Mamba는 재계산(Recomputation) 기법을 도입했다.3 역전파 단계에서 필요한 중간 상태들을 HBM에서 불러오는 대신, SRAM 내부에서 순전파 연산을 다시 수행하여 값을 재생성한다.
- 전통적 관점: 연산을 두 번 하는 것은 비효율적이다.
- 하드웨어 인식 관점: 연산(Compute)은 매우 빠르고 저렴한 반면, 메모리 접근(Memory Access)은 느리고 비싸다. 따라서 메모리 접근을 줄이기 위해 연산을 더 수행하는 것이 전체 속도 면에서 훨씬 유리하다.
이 전략 덕분에 Mamba의 메모리 사용량은 시퀀스 길이에 대해 선형적으로 비례하며, 트랜스포머보다 훨씬 긴 시퀀스를 단일 GPU에서 학습할 수 있는 기반이 된다.
5. 성능 벤치마크 및 비교 분석
Mamba의 병렬 스캔 및 하드웨어 최적화가 실제 성능에 미치는 영향은 벤치마크 결과를 통해 명확히 드러난다.
5.1 시퀀스 길이에 따른 학습 속도 비교
Mamba와 FlashAttention-2가 적용된 트랜스포머의 학습 속도를 비교하면, 명확한 **교차점(Crossover Point)**이 존재한다.18
- 짧은 시퀀스 (L < 2K): 트랜스포머(FlashAttention-2)가 더 빠르거나 유사한 성능을 보인다. 짧은 시퀀스에서는 행렬 곱(Matmul) 연산이 지배적이어서 GPU의 Tensor Core를 극한으로 활용하는 트랜스포머가 유리하다. Mamba의 스캔 연산은 순차적 요소가 일부 남아있어 오버헤드가 발생한다.
- 중간 시퀀스 (2K \le L \le 16K): Mamba가 트랜스포머를 앞지르기 시작한다. 트랜스포머의 연산량이 L^2로 증가하는 반면, Mamba는 L에 비례하여 선형적으로 증가하기 때문이다.
- 긴 시퀀스 (L > 16K): 격차가 극적으로 벌어진다. 128K 이상의 길이에서 트랜스포머는 메모리 부족(OOM) 현상을 겪거나 속도가 급격히 저하되지만, Mamba는 안정적인 처리량(Throughput)을 유지한다.
5.2 실제 하드웨어 처리량 데이터
구체적인 실험 데이터(A100 80GB 및 MI250 GPU 기준)를 살펴보면 다음과 같은 경향성을 확인할 수 있다.18
| 시퀀스 길이 (Sequence Length) | Mamba Throughput (tokens/sec) | Transformer (FlashAttn-2) | 비고 |
|---|---|---|---|
| 4,096 | ~44,000 | ~42,000 | 성능 유사, Mamba 소폭 우세 |
| 8,192 | ~43,500 | ~35,000 | Mamba의 우위 시작 |
| 16,384 | ~43,000 | ~20,000 | Transformer 성능 급감 (O(L^2)) |
| 1M (1,000,000) | 작동 가능 (Stable) | OOM (Out of Memory) | Mamba의 독보적 영역 |
이 데이터는 Mamba의 병렬 스캔 구현이 이론적인 선형 복잡도를 실제 물리적 시간 단축으로 완벽하게 전이시켰음을 증명한다. 특히 1백만 토큰 길이의 시퀀스에서도 학습이 가능하다는 점은 DNA 분석이나 초고해상도 비디오 처리와 같은 새로운 도메인으로의 확장을 가능케 한다.7
5.3 수치적 안정성(Numerical Stability)
병렬 스캔은 수치적 안정성 측면에서도 중요한 고려사항을 가진다. 반복적인 곱셈 연산(A_j A_i)은 값이 기하급수적으로 커지거나(Exploding) 작아질(Vanishing) 위험이 있다. Mamba는 이를 방지하기 위해 **로그 공간(Log-space)**에서의 연산이나 정규화(Normalization) 기법을 스캔 알고리즘 내부에 통합하여 수치적 안정성을 확보한다.3 이는 특히 부동소수점 정밀도가 낮은 BF16이나 FP16 환경에서 긴 시퀀스를 학습할 때 필수적인 요소이다.
6. 결론: 효율성과 성능의 완전한 통합
Mamba 아키텍처에서 ’4.1 병렬 스캔을 통한 학습 효율성 확보’는 단순한 기술적 최적화의 범주를 넘어선다. 이는 **1) 수학적 알고리즘(결합 법칙과 스캔), 2) 시스템 아키텍처(선택적 SSM), 3) 하드웨어 엔지니어링(커널 퓨전과 메모리 계층 최적화)**이 삼위일체를 이루어낸 결과물이다.
Mamba는 RNN의 장점인 선형 추론 비용과 트랜스포머의 장점인 병렬 학습 능력을 동시에 달성함으로써, 시퀀스 모델링의 오랜 난제였던 ’효율성-성능 트레이드오프’를 근본적으로 해체했다. 이는 향후 등장할 모델들이 단순히 파라미터 수를 늘리는 경쟁에서 벗어나, 얼마나 효율적으로 긴 문맥을 처리하고 정보를 압축할 수 있는지에 대한 ‘알고리즘적 효율성’ 경쟁으로 나아가는 전환점이 될 것이다. 병렬 스캔은 그 새로운 시대의 핵심 엔진으로서, Mamba가 포스트 트랜스포머 시대의 주역으로 부상하게 하는 가장 강력한 동력원이다.
7. 참고 자료
- Mamba No. 5 (A Little Bit Of…) - Sparse Notes, https://jameschen.io/jekyll/update/2024/02/12/mamba.html
- How does Mamba reduce time complexity of Transformers? - Medium, https://medium.com/@colaglory/how-mamba-reduces-time-complexity-of-transformers-9eb350f5f368
- Mamba: Linear-Time Sequence Modeling with Selective State Spaces, https://arxiv.org/pdf/2312.00752
- Mamba: Linear-Time Sequence Modeling with Selective State Spaces, https://arxiv.org/html/2312.00752v2
- Mamba Selective State Space Model - Emergent Mind, https://www.emergentmind.com/topics/mamba-based-selective-state-space-model
- Here Comes Mamba: The Selective State Space Model, https://towardsdatascience.com/here-comes-mamba-the-selective-state-space-model-435e5d17a451/
- Mamba: Efficient Sequence Modeling with SSMs - Emergent Mind, https://www.emergentmind.com/papers/2312.00752
- BPPSA: Scaling Back-propagation by Parallel Scan Algorithm, https://www.cs.toronto.edu/ecosystem/papers/MLSys_20/BPPSA.pdf
- Learning Triton One Kernel At a Time: Vector Addition, https://towardsdatascience.com/learning-triton-one-kernel-at-a-time-vector-addition/
- Mamba-2: Algorithms and Systems, https://pli.princeton.edu/blog/2024/mamba-2-algorithms-and-systems
- AMS 148 Chapter 5: Reduce and Scan, https://ams148-spring18-01.courses.soe.ucsc.edu/system/files/attachments/note5.pdf
- Tri Dao, https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1244/slides/cs224n-2024-lecture18-deployment-and-efficiency.pdf
- Why isn’t selective_scan implemented in Triton? · Issue #35 - GitHub, https://github.com/state-spaces/mamba/issues/35
- A Visual Guide to Mamba and State Space Models, https://www.maartengrootendorst.com/blog/mamba/
- Mamba: Make Sequence Models Fast Again | by Dong-Keon Kim, https://medium.com/@kdk199604/mamba-make-sequence-models-fast-again-540245a49155
- Introducing Triton: Open-source GPU programming for neural …, https://openai.com/index/triton/
- Mamba Model Blog - HackMD, https://hackmd.io/Btjp7ZMRQGCLh93n1vMAVw
- Training Mamba Models on AMD MI250/MI250X GPUs with Custom …, https://www.lighton.ai/lighton-blogs/training-mamba-models-on-amd-mi250-mi250x-gpus-with-custom-kernels
- how to compare mamba with flashattention2 · Issue #27 - GitHub, https://github.com/state-spaces/mamba/issues/27
- Towards Scalable and Stable Parallelization of Nonlinear RNNs, https://arxiv.org/html/2407.19115v2